# -*- coding: utf-8 -*-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os.path
import pickle
import random

import cv2
import numpy as np
import tensorflow as tf
from scipy.misc import imread
from scipy.misc import imresize
from scipy.misc import imsave

import align.detect_face
from facenet import facenet

# Directory holding the pre-trained FaceNet model parameters
modeldir = 'models/20190619-150742'


def align_face(input_dir, output_dir, margin=32, detect_multiple_faces=True, random_order=True, gpu_memory_fraction=0.7):
    """Detect faces with MTCNN and save aligned 160x160 crops as PNGs.

    Mirrors the class-subdirectory layout of ``input_dir`` into
    ``output_dir`` and records every crop in a
    ``bounding_boxes_<key>.txt`` manifest (images that could not be
    aligned are listed without coordinates).

    Args:
        input_dir: directory of per-person subdirectories with raw images.
        output_dir: directory where aligned crops and the manifest are written.
        margin: pixels of extra context added around each detected box.
        detect_multiple_faces: if True, save every detected face with an
            ``_<i>`` filename suffix; otherwise keep only the largest,
            most centered face.
        random_order: shuffle classes and image paths so several processes
            can align the same dataset concurrently without colliding.
        gpu_memory_fraction: per-process GPU memory fraction for TensorFlow.
    """
    output_dir = os.path.expanduser(output_dir)
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    dataset = facenet.get_dataset(input_dir)

    print('Creating networks and loading parameters')
    face_size = 160  # output crop is face_size x face_size pixels
    with tf.Graph().as_default():
        gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=gpu_memory_fraction)
        sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options, log_device_placement=False))
        with sess.as_default():
            pnet, rnet, onet = align.detect_face.create_mtcnn(sess, None)

    minsize = 20  # minimum size of face
    threshold = [0.6, 0.7, 0.7]  # three steps's threshold
    factor = 0.709  # scale factor

    # Add a random key to the filename to allow alignment using multiple processes
    random_key = np.random.randint(0, high=99999)
    bounding_boxes_filename = os.path.join(output_dir, 'bounding_boxes_%05d.txt' % random_key)

    with open(bounding_boxes_filename, "w") as text_file:
        nrof_images_total = 0
        nrof_successfully_aligned = 0
        if random_order:
            random.shuffle(dataset)
        for cls in dataset:
            output_class_dir = os.path.join(output_dir, cls.name)
            if not os.path.exists(output_class_dir):
                os.makedirs(output_class_dir)
            # BUG FIX: this shuffle used to be nested inside the makedirs
            # branch above, so image order was only randomized when the
            # output directory did not already exist.
            if random_order:
                random.shuffle(cls.image_paths)
            for image_path in cls.image_paths:
                nrof_images_total += 1
                filename = os.path.splitext(os.path.split(image_path)[1])[0]
                output_filename = os.path.join(output_class_dir, filename + '.png')
                print(image_path)
                if not os.path.exists(output_filename):
                    try:
                        img = imread(image_path)
                    except (IOError, ValueError, IndexError) as e:
                        errorMessage = '{}: {}'.format(image_path, e)
                        print(errorMessage)
                    else:
                        if img.ndim < 2:
                            # Not a usable image; log it with no box coordinates.
                            print('Unable to align "%s"' % image_path)
                            text_file.write('%s\n' % (output_filename))
                            continue
                        if img.ndim == 2:
                            img = facenet.to_rgb(img)
                        img = img[:, :, 0:3]  # drop alpha channel if present

                        bounding_boxes, _ = align.detect_face.detect_face(img, minsize, pnet, rnet, onet, threshold,
                                                                          factor)
                        nrof_faces = bounding_boxes.shape[0]
                        if nrof_faces > 0:
                            det = bounding_boxes[:, 0:4]
                            det_arr = []
                            img_size = np.asarray(img.shape)[0:2]
                            if nrof_faces > 1:
                                if detect_multiple_faces:
                                    for i in range(nrof_faces):
                                        det_arr.append(np.squeeze(det[i]))
                                else:
                                    # Keep the single best face: score boxes by
                                    # area minus (weighted) squared distance
                                    # from the image center.
                                    bounding_box_size = (det[:, 2] - det[:, 0]) * (det[:, 3] - det[:, 1])
                                    img_center = img_size / 2
                                    offsets = np.vstack([(det[:, 0] + det[:, 2]) / 2 - img_center[1],
                                                         (det[:, 1] + det[:, 3]) / 2 - img_center[0]])
                                    offset_dist_squared = np.sum(np.power(offsets, 2.0), 0)
                                    index = np.argmax(
                                        bounding_box_size - offset_dist_squared * 2.0)  # some extra weight on the centering
                                    det_arr.append(det[index, :])
                            else:
                                det_arr.append(np.squeeze(det))

                            for i, det in enumerate(det_arr):
                                det = np.squeeze(det)
                                # Expand the detection by margin/2 on every side,
                                # clamped to the image bounds.
                                bb = np.zeros(4, dtype=np.int32)
                                bb[0] = np.maximum(det[0] - margin / 2, 0)
                                bb[1] = np.maximum(det[1] - margin / 2, 0)
                                bb[2] = np.minimum(det[2] + margin / 2, img_size[1])
                                bb[3] = np.minimum(det[3] + margin / 2, img_size[0])
                                cropped = img[bb[1]:bb[3], bb[0]:bb[2], :]
                                scaled = imresize(cropped, (face_size, face_size), interp='bilinear')
                                nrof_successfully_aligned += 1
                                filename_base, file_extension = os.path.splitext(output_filename)
                                if detect_multiple_faces:
                                    output_filename_n = "{}_{}{}".format(filename_base, i, file_extension)
                                else:
                                    output_filename_n = "{}{}".format(filename_base, file_extension)
                                imsave(output_filename_n, scaled)
                                text_file.write('%s %d %d %d %d\n' % (output_filename_n, bb[0], bb[1], bb[2], bb[3]))
                        else:
                            print('Unable to align "%s"' % image_path)
                            text_file.write('%s\n' % (output_filename))

    sess.close()
    print('Total number of images: %d' % nrof_images_total)
    print('Number of successfully aligned images: %d' % nrof_successfully_aligned)


def create_embedding(src, embedding_file, gpu_memory_fraction=1):
    """Compute a FaceNet embedding for every aligned face image under *src*
    and pickle the resulting list of (person_name, embedding) pairs.

    Args:
        src: directory with one subdirectory per person, each holding
            aligned face images (files directly inside the subdirectory).
        embedding_file: path of the pickle file to write.
        gpu_memory_fraction: per-process GPU memory fraction for TensorFlow.
    """
    # Gather (person_name, image_path) pairs: skip plain files at the top
    # level and subdirectories at the second level.
    labeled_paths = [
        (person, os.path.join(src, person, entry))
        for person in os.listdir(src)
        if not os.path.isfile(os.path.join(src, person))
        for entry in os.listdir(os.path.join(src, person))
        if not os.path.isdir(os.path.join(src, person, entry))
    ]
    face_size = 160  # network input is face_size x face_size pixels
    with tf.Graph().as_default():
        gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=gpu_memory_fraction)
        session_config = tf.ConfigProto(gpu_options=gpu_options, log_device_placement=False)
        with tf.Session(config=session_config) as sess:
            facenet.load_model(modeldir, sess)
            graph = tf.get_default_graph()
            input_tensor = graph.get_tensor_by_name("input:0")
            embeddings_tensor = graph.get_tensor_by_name("embeddings:0")
            train_flag = graph.get_tensor_by_name("phase_train:0")

            print('facenet embedding模型建立完毕')
            embedding_set = []
            for person, img_path in labeled_paths:
                # Load, resize to the network input size, and whiten.
                face = imread(img_path, mode='RGB')
                face = cv2.resize(face, (face_size, face_size), interpolation=cv2.INTER_CUBIC)
                face = facenet.prewhiten(face)
                batch = face.reshape(-1, face_size, face_size, 3)

                vector = sess.run(
                    embeddings_tensor,
                    feed_dict={input_tensor: batch, train_flag: False})[0]
                embedding_set.append((person, vector))
            print(len(embedding_set))
            with open(embedding_file, "wb") as fp:
                pickle.dump(embedding_set, fp)
            print("done")


if __name__ == '__main__':
    # origin_faces_path = "F:\\photos"            # path to the raw photos
    aligned_faces_path = "F:\\photos1"          # directory holding the cropped/aligned face images
    embedding_file = "./models/embedding.pk"   # pickle file for the serialized (name, embedding) pairs

    # Alignment and embedding can be run as two separate passes, which makes
    # it easy to inspect the alignment results in between.
    # Face alignment:
    # align_face(origin_faces_path, aligned_faces_path)

    # Build the encoded (embedded) face data
    create_embedding(aligned_faces_path, embedding_file)

